{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huangzhitong/anaconda3/envs/unicolor/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import torchvision.transforms as T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_path = '../results/old_coco_259999/original'\n",
    "\n",
    "test_paths = [\n",
    "    '../results/old_coco_259999/colorized',\n",
    "    '../results/new_coco_299999/colorized'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FID score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:04<00:00,  1.55it/s]\n",
      "100%|██████████| 100/100 [01:06<00:00,  1.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results/old_coco_259999/original \n",
      " ../results/old_coco_259999/colorized \n",
      " FID score: 7.809005101211142\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:08<00:00,  1.46it/s]\n",
      "100%|██████████| 100/100 [01:07<00:00,  1.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../results/old_coco_259999/original \n",
      " ../results/new_coco_299999/colorized \n",
      " FID score: 8.265271639647949\n"
     ]
    }
   ],
   "source": [
    "from pytorch_fid import fid_score\n",
    "\n",
    "for path in test_paths:\n",
    "    fid = fid_score.calculate_fid_given_paths(paths=[original_path, path], batch_size=50, dims=2048, device='cuda:0')\n",
    "    print(original_path, '\\n', path, '\\n', f'FID score: {fid}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Colorfulness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_colorfulness(image):\n",
    "\t# split the image into its respective RGB components\n",
    "\t(B, G, R) = cv2.split(image.astype(\"float\"))\n",
    "\n",
    "\t# compute rg = R - G\n",
    "\trg = np.absolute(R - G)\n",
    "\n",
    "\t# compute yb = 0.5 * (R + G) - B\n",
    "\tyb = np.absolute(0.5 * (R + G) - B)\n",
    "\n",
    "\t# compute the mean and standard deviation of both `rg` and `yb`\n",
    "\t(rbMean, rbStd) = (np.mean(rg), np.std(rg))\n",
    "\t(ybMean, ybStd) = (np.mean(yb), np.std(yb))\n",
    "\n",
    "\t# combine the mean and standard deviations\n",
    "\tstdRoot = np.sqrt((rbStd ** 2) + (ybStd ** 2))\n",
    "\tmeanRoot = np.sqrt((rbMean ** 2) + (ybMean ** 2))\n",
    "\n",
    "\t# derive the \"colorfulness\" metric and return it\n",
    "\treturn stdRoot + (0.3 * meanRoot)\n",
    "\n",
    "for path in test_paths:\n",
    "    files = os.listdir(path)\n",
    "    colorful = []\n",
    "    for i in tqdm(range(len(files))):\n",
    "        img = Image.open( os.path.join(path, files[i]) ).convert('RGB').resize([256, 256])\n",
    "        img = np.array(img)\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n",
    "        colorful.append( image_colorfulness(img) )\n",
    "\n",
    "    colorful = np.array(colorful)\n",
    "    colorful = np.mean(colorful)\n",
    "    print(path, '\\n', f'Colorfulness: {colorful}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PSNR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(img, size):\n",
    "    if size == None:\n",
    "        size = [(img.size[1] // 16) * 16, (img.size[0] // 16) * 16]\n",
    "    transform = T.Compose([T.Resize(size), T.ToTensor()])\n",
    "    x = transform(img)\n",
    "    x = x.unsqueeze(0)\n",
    "    return x\n",
    "\n",
    "with torch.no_grad():\n",
    "    for path in test_paths:\n",
    "        src_files = os.listdir(original_path)\n",
    "        tgt_files = os.listdir(path)\n",
    "        psnrs = []\n",
    "        for i in tqdm(range(len(src_files))):\n",
    "            assert src_files[i].split('.')[0] == tgt_files[i].split('.')[0]\n",
    "            img0 = Image.open( os.path.join(original_path, src_files[i]) ).convert('RGB')\n",
    "            img1 = Image.open( os.path.join(path, tgt_files[i]) ).convert('RGB')\n",
    "            x0 = preprocess(img0, [256, 256])\n",
    "            x1 = preprocess(img1, [256, 256])\n",
    "            mse = torch.mean((x0 - x1) ** 2)\n",
    "            psnr = 10 * torch.log10(1.0 / mse)\n",
    "            psnrs.append(psnr.item())\n",
    "        psnrs = np.mean(psnrs)\n",
    "        print(original_path, '\\n', path, '\\n', f'PSNR: {psnrs}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LPIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lpips\n",
    "\n",
    "def preprocess(img, size):\n",
    "    if size == None:\n",
    "        size = [(img.size[1] // 16) * 16, (img.size[0] // 16) * 16]\n",
    "    transform = T.Compose([T.Resize(size), T.ToTensor()])\n",
    "    x = transform(img)\n",
    "    x = x * 2 - 1\n",
    "    x = x.unsqueeze(0)\n",
    "    return x\n",
    "\n",
    "loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores\n",
    "#loss_fn_vgg = lpips.LPIPS(net='vgg') # closer to \"traditional\" perceptual loss, when used for optimization\n",
    "\n",
    "with torch.no_grad():\n",
    "    for path in test_paths:\n",
    "        src_files = os.listdir(original_path)\n",
    "        tgt_files = os.listdir(path)\n",
    "        distance = []\n",
    "        for i in tqdm(range(len(src_files))):\n",
    "            assert src_files[i].split('.')[0] == tgt_files[i].split('.')[0]\n",
    "            img0 = Image.open( os.path.join(original_path, src_files[i]) ).convert('RGB')\n",
    "            img1 = Image.open( os.path.join(path, tgt_files[i]) ).convert('RGB')\n",
    "            x0 = preprocess(img0, [256, 256])\n",
    "            x1 = preprocess(img1, [256, 256])\n",
    "            d = loss_fn_alex(x0, x1)\n",
    "            distance.append(d)\n",
    "        distance = torch.cat(distance, axis=0)\n",
    "        distance = distance.mean()\n",
    "        print(original_path, '\\n', path, '\\n', f'LPIPS: {distance}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_msssim import ssim, SSIM\n",
    "\n",
    "def preprocess(img, size):\n",
    "    if size == None:\n",
    "        size = [(img.size[1] // 16) * 16, (img.size[0] // 16) * 16]\n",
    "    transform = T.Compose([T.Resize(size), T.ToTensor()])\n",
    "    x = transform(img)\n",
    "    x = x.unsqueeze(0)\n",
    "    return x\n",
    "\n",
    "ssim_module = SSIM(data_range=1, size_average=True, channel=3)\n",
    "batch_size = 100\n",
    "\n",
    "for path in test_paths:\n",
    "    src_files = os.listdir(original_path)\n",
    "    tgt_files = os.listdir(path)\n",
    "    ssim_scores = []\n",
    "    x0 = []\n",
    "    x1 = []\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm(range(len(src_files))):\n",
    "            assert src_files[i].split('.')[0] == tgt_files[i].split('.')[0]\n",
    "            img0 = Image.open( os.path.join(original_path, src_files[i]) ).convert('RGB')\n",
    "            img1 = Image.open( os.path.join(path, tgt_files[i]) ).convert('RGB')\n",
    "            x0.append(preprocess(img0, [256, 256]))\n",
    "            x1.append(preprocess(img1, [256, 256]))\n",
    "\n",
    "            if (i+1) % batch_size == 0:\n",
    "                x0 = torch.cat(x0, dim=0)\n",
    "                x1 = torch.cat(x1, dim=0)\n",
    "                ssim_loss = ssim_module(x0, x1)\n",
    "                ssim_scores.append(ssim_loss.cpu().numpy())\n",
    "                x0 = []\n",
    "                x1 = []\n",
    "    ssim = np.mean(ssim_scores)\n",
    "    print(original_path, '\\n', path, '\\n', f'SSIM: {ssim}')\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Contextual loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextual_loss as cl\n",
    "\n",
    "def preprocess(img, size):\n",
    "    if size == None:\n",
    "        size = [(img.size[1] // 16) * 16, (img.size[0] // 16) * 16]\n",
    "    transform = T.Compose([T.Resize(size), T.ToTensor()])\n",
    "    x = transform(img)\n",
    "    x = x.unsqueeze(0)\n",
    "    return x\n",
    "\n",
    "criterion = cl.ContextualLoss(use_vgg=False, loss_type='l1').cuda()\n",
    "\n",
    "for path in test_paths:\n",
    "    src_files = os.listdir(original_path)\n",
    "    tgt_files = os.listdir(path)\n",
    "    losses = []\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm(range(len(src_files))):\n",
    "            assert src_files[i].split('.')[0] == tgt_files[i].split('.')[0]\n",
    "            img0 = Image.open( os.path.join(original_path, src_files[i]) ).convert('RGB')\n",
    "            img1 = Image.open( os.path.join(path, tgt_files[i]) ).convert('RGB')\n",
    "            x0 = preprocess(img0, [96, 96]).cuda()\n",
    "            x1 = preprocess(img1, [96, 96]).cuda()\n",
    "\n",
    "            loss = criterion(x0, x1)\n",
    "\n",
    "            losses.append(loss.cpu().numpy())\n",
    "            \n",
    "    print(original_path, '\\n', path, '\\n', f'Contextual loss: {np.mean(losses)}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextual_loss as cl\n",
    "\n",
    "def preprocess(img, size):\n",
    "    if size == None:\n",
    "        size = [(img.size[1] // 16) * 16, (img.size[0] // 16) * 16]\n",
    "    transform = T.Compose([T.Resize(size), T.ToTensor()])\n",
    "    x = transform(img)\n",
    "    x = x.unsqueeze(0)\n",
    "    return x\n",
    "\n",
    "criterion = cl.ContextualLoss(use_vgg=False, loss_type='l2').cuda()\n",
    "\n",
    "for path in test_paths:\n",
    "    src_files = os.listdir(original_path)\n",
    "    tgt_files = os.listdir(path)\n",
    "    losses = []\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm(range(len(src_files))):\n",
    "            assert src_files[i].split('.')[0] == tgt_files[i].split('.')[0]\n",
    "            img0 = Image.open( os.path.join(original_path, src_files[i]) ).convert('RGB')\n",
    "            img1 = Image.open( os.path.join(path, tgt_files[i]) ).convert('RGB')\n",
    "            x0 = preprocess(img0, [96, 96]).cuda()\n",
    "            x1 = preprocess(img1, [96, 96]).cuda()\n",
    "\n",
    "            loss = criterion(x0, x1)\n",
    "\n",
    "            losses.append(loss.cpu().numpy())\n",
    "            \n",
    "    print(original_path, '\\n', path, '\\n', f'Contextual loss: {np.mean(losses)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextual_loss as cl\n",
    "\n",
    "def preprocess(img, size):\n",
    "    if size == None:\n",
    "        size = [(img.size[1] // 16) * 16, (img.size[0] // 16) * 16]\n",
    "    transform = T.Compose([T.Resize(size), T.ToTensor()])\n",
    "    x = transform(img)\n",
    "    x = x.unsqueeze(0)\n",
    "    return x\n",
    "\n",
    "criterion = cl.ContextualLoss(use_vgg=False, loss_type='cosine').cuda()\n",
    "\n",
    "for path in test_paths:\n",
    "    src_files = os.listdir(original_path)\n",
    "    tgt_files = os.listdir(path)\n",
    "    losses = []\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm(range(len(src_files))):\n",
    "            assert src_files[i].split('.')[0] == tgt_files[i].split('.')[0]\n",
    "            img0 = Image.open( os.path.join(original_path, src_files[i]) ).convert('RGB')\n",
    "            img1 = Image.open( os.path.join(path, tgt_files[i]) ).convert('RGB')\n",
    "            x0 = preprocess(img0, [96, 96]).cuda()\n",
    "            x1 = preprocess(img1, [96, 96]).cuda()\n",
    "\n",
    "            loss = criterion(x0, x1)\n",
    "\n",
    "            losses.append(loss.cpu().numpy())\n",
    "            \n",
    "    print(original_path, '\\n', path, '\\n', f'Contextual loss: {np.mean(losses)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('unicolor')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d2a626b74e299e13c7ec875d6c47b0f8e040df285fc56a1acf5a98c054a62f0a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
